package com.cml.batisext.core.repository.mybatis.base.sqlprovider;

import java.lang.reflect.Field;
import java.util.Map;

import com.cml.batisext.core.bean.base.DatabaseBean;
import com.cml.batisext.core.repository.mybatis.annotations.DbField;
import com.cml.batisext.core.repository.mybatis.annotations.DbTable;
import com.cml.batisext.core.repository.mybatis.base.BaseDao;
import com.cml.batisext.core.repository.mybatis.base.BaseDaoSqlProvider;

/**
 * SQL generator for the {@code update} methods of the generic DAO interface.
 * <p>
 * Every statement produced here uses optimistic locking: it appends
 * {@code version = version + 1} to the SET clause and restricts the update to the row whose
 * {@code id} and {@code version} match the bean. Note that generating the SQL also mutates the
 * bean: its version is incremented through {@code DatabaseBean.setVersion} while the statement is
 * being built. So if the update later fails or matches no row, the in-memory version has already
 * been bumped.
 * @author lwh
 *
 */
public class BaseDaoUpdateSqlProvider extends BaseDaoSqlProvider {

	/**
	 * Builds an UPDATE that only sets the explicitly requested columns.
	 *
	 * @param params   MyBatis parameter map. The bean is read from {@code BaseDao.ENTRY_PARAM_NAME}.
	 *                 The Java field names to update are read from {@code BaseDao.COLUMNS_PARAM_NAME}
	 *                 (a {@code String[]}).
	 * @param beanType entity class; must be annotated with {@link DbTable}
	 * @return the UPDATE statement; requested fields whose value is null are not included
	 * @throws NoSuchFieldException if a requested name is not a field declared directly on
	 *                              {@code beanType}
	 * @throws IllegalArgumentException if no column list was supplied or {@code beanType} lacks
	 *                                  {@link DbTable}
	 * @throws Exception if a requested field is not annotated with {@link DbField}, or the bean's id
	 *                   is null
	 */
	public String doBuildSql(Map<String, Object> params, Class<?> beanType) throws Exception {
		String tableName = resolveTableName(beanType);

		Object beanParam = params.get(BaseDao.ENTRY_PARAM_NAME);
		String[] columnParam = (String[]) params.get(BaseDao.COLUMNS_PARAM_NAME);
		if (columnParam == null) {
			// Previously an NPE in the for-each loop; report the missing argument explicitly.
			throw new IllegalArgumentException("未指定需要update的字段: " + BaseDao.COLUMNS_PARAM_NAME);
		}

		StringBuilder sqlSetField = new StringBuilder();
		for (String columnName : columnParam) {
			Field javaField = beanType.getDeclaredField(columnName);
			DbField dbField = javaField.getAnnotation(DbField.class);
			// Only database-mapped properties may be updated; an explicitly requested
			// non-database field is a caller error rather than something to skip.
			if (dbField == null) {
				throw new Exception("字段" + columnName + "是非数据库字段，不能update");
			}
			appendAssignment(sqlSetField, beanParam, beanType, javaField, dbField);
		}

		return buildVersionedUpdate(tableName, sqlSetField, beanParam);
	}

	/**
	 * Builds an UPDATE that sets every {@link DbField}-annotated field declared on
	 * {@code beanType} whose current value on {@code param} is non-null.
	 *
	 * @param param    the bean to persist
	 * @param beanType entity class; must be annotated with {@link DbTable}
	 * @return the UPDATE statement
	 * @throws IllegalArgumentException if {@code beanType} lacks {@link DbTable}
	 * @throws Exception if the bean's id is null, or reflection on the bean fails
	 */
	@Override
	public String doBuildSql(Object param, Class<?> beanType) throws Exception {
		String tableName = resolveTableName(beanType);

		StringBuilder sqlSetField = new StringBuilder();
		for (Field javaField : beanType.getDeclaredFields()) {
			DbField dbField = javaField.getAnnotation(DbField.class);
			// Properties that are not mapped to a database column are silently skipped.
			if (dbField == null) {
				continue;
			}
			appendAssignment(sqlSetField, param, beanType, javaField, dbField);
		}

		return buildVersionedUpdate(tableName, sqlSetField, param);
	}

	/**
	 * Returns the table name declared by the entity's {@link DbTable} annotation, failing with a
	 * descriptive message instead of an NPE when the annotation is missing.
	 */
	private static String resolveTableName(Class<?> beanType) {
		DbTable table = beanType.getAnnotation(DbTable.class);
		if (table == null) {
			throw new IllegalArgumentException(beanType.getName() + "缺少@DbTable注解，无法生成update语句");
		}
		return table.table();
	}

	/**
	 * Appends {@code column=value,} for one mapped field. Fields whose current value is null
	 * are left out, so the corresponding column keeps its stored value.
	 */
	private void appendAssignment(StringBuilder sqlSetField, Object bean, Class<?> beanType,
			Field javaField, DbField dbField) throws Exception {
		// Same call order as before: resolve column metadata first, then read the value.
		String columnName = getColumnName(javaField, dbField);
		Class<?> columnType = getColumnType(javaField, dbField);
		Object fieldValue = getColumnValue(bean, javaField, beanType);
		if (fieldValue == null) {
			return;
		}
		// NOTE(review): the value is embedded literally in the SQL text; whether setupFieldValue
		// quotes/escapes it is not visible from this class - confirm before accepting untrusted input.
		sqlSetField.append(columnName)
				.append('=')
				.append(setupFieldValue(columnType, fieldValue))
				.append(',');
	}

	/**
	 * Validates the primary key, increments the bean's version, and assembles the final
	 * optimistic-locking UPDATE. The trailing comma left by {@link #appendAssignment} is absorbed
	 * by the {@code version = version + 1} assignment that always follows it.
	 */
	private String buildVersionedUpdate(String tableName, CharSequence sqlSetField, Object bean)
			throws Exception {
		Long pk = getPKValue(bean);
		if (pk == null) {
			throw new Exception("传入bean中id字段必须有值");
		}

		Object versionValue = DatabaseBean.class.getMethod("getVersion").invoke(bean);
		// A null version can never match "version = null" in SQL and would NPE when incremented.
		if (versionValue == null) {
			throw new IllegalStateException("传入bean中version字段必须有值");
		}
		Integer currentVersion = (Integer) versionValue;
		// The bean is updated to the version the row will have after a successful UPDATE.
		DatabaseBean.class.getMethod("setVersion", Integer.class).invoke(bean, currentVersion + 1);

		return "update " + tableName + " set "
				+ sqlSetField + "version = version + 1"
				+ " where id = " + pk + " and version = " + currentVersion;
	}
}
